#!/usr/bin/env python3
import sys

import numpy as np
import torch
from scipy.stats import wasserstein_distance

from . import NewStateDetector

class DiscreteSigma(NewStateDetector):
    """Count-based novelty bonus over discretised states.

    States are discretised by rendering them with ``np.array2string`` at
    ``state_precision`` digits; the resulting string is the key of a visit-count
    table. For an incoming state with count ``N`` the bonus is::

        beta * sqrt(2 * horizon**2 * log(2 / delta) / max(N, 1))

    Parameters
    ----------
    horizon : number
        ``horizon`` term in the bonus formula.
    delta : float
        Confidence parameter; enters the bonus as ``log(2 / delta)``.
    state_precision : int
        Float precision used when turning a state into its hash key.
    beta : float
        Multiplicative scale of the bonus.
    debug : bool
        Print the bonus on every call.
    """

    def __init__(
        self, horizon=3000, delta=0.02, state_precision=4, beta=0.1, debug=True
    ):
        self._horizon = horizon
        self._delta = delta
        self._state_precision = state_precision
        # Visit counts keyed by the discretised state string; persists across calls.
        self._state_map = dict()
        self._debug = debug
        self._beta = beta

    def _state_key(self, state):
        """Return the hash-table key for ``state``.

        ``threshold=sys.maxsize`` disables NumPy's "..." summarisation, which
        otherwise applies to arrays larger than the print threshold (1000
        elements by default) and would map distinct large states to one key.
        """
        return np.array2string(
            state, precision=self._state_precision, threshold=sys.maxsize
        )

    def discrete_sigma(self, states_in_buffer, state_incoming):
        """Update the visit counts and return the bonus for ``state_incoming``.

        Parameters
        ----------
        states_in_buffer : iterable of array-like
            States whose keys are added to the count table.
        state_incoming : array-like
            State to score; its count is read and then incremented.

        Returns
        -------
        float
            ``beta * sqrt(2 * horizon**2 * log(2 / delta) / max(N, 1))`` where
            ``N`` is the count of ``state_incoming`` before this call's increment.
        """
        # NOTE(review): every call re-adds all buffer states to the persistent
        # self._state_map, so counts accumulate across calls instead of
        # reflecting the current buffer contents. Behaviour kept as-is; the
        # original author intended this table to be global and not rebuilt per
        # call — confirm the intended counting semantics.
        for buffered_state in states_in_buffer:
            key = self._state_key(buffered_state)
            self._state_map[key] = self._state_map.get(key, 0) + 1

        incoming_key = self._state_key(state_incoming)
        count = self._state_map.get(incoming_key, 0)
        self._state_map[incoming_key] = count + 1

        # An unseen state is treated as visited once to avoid division by zero.
        count = max(count, 1)
        sigma = self._beta * np.sqrt(
            2 * np.square(self._horizon) * np.log(2 / self._delta) / count
        )

        if self._debug:
            print(sigma)

        return sigma

    def evaluate(self, state: np.ndarray):
        """Evaluate the new-state detector on a single state.

        Returns
            - a score (float) that represents the novelty of the state
        """
        # NOTE(review): self._buffer is not assigned anywhere in this class;
        # presumably provided by NewStateDetector or set by the caller — confirm.
        return self.discrete_sigma(self._buffer, state)


class EuclideanDistance(NewStateDetector):
    """Score a state's novelty by its Euclidean distance to buffered states.

    Parameters
    ----------
    min_n : int
        Number of nearest buffered states averaged for ``"mean_min_n_dist"``.
    threshold_accept : float
        Distances ``<= threshold_accept`` count as accepted matches for
        ``"gap_accept_dist"``; 0 means only exact matches are accepted.
    option : str
        Statistic returned by :meth:`evaluate`: one of ``"mean_dist"``,
        ``"min_dist"``, ``"mean_min_n_dist"``, ``"gap_accept_dist"``.
    debug : bool
        Print all statistics on every :meth:`evaluate` call.
    """

    def __init__(self, min_n=2, threshold_accept=0, option="mean_dist", debug=False):
        self._min_n = min_n
        self._threshold_accept = threshold_accept
        self._debug = debug
        self._option = option  # {mean_dist, min_dist, mean_min_n_dist, gap_accept_dist}

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array when it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _summarize(self, dist_arr):
        """Reduce a 1-D array of distances to the statistics tuple.

        Returns ``(mean, min, mean of the min_n smallest, mean of accepted
        distances or NaN, number of accepted distances)``; the first three
        are rounded to 2 decimals.
        """
        accepted = dist_arr[dist_arr <= self._threshold_accept]
        gap_accept_n = len(accepted)
        gap_accept_dist = np.mean(accepted) if gap_accept_n else np.nan
        nearest = np.sort(dist_arr)[: self._min_n]
        # np.round rather than np.round_, which NumPy 2.0 removed.
        return (
            np.round(np.mean(dist_arr), decimals=2),
            np.round(np.min(dist_arr), decimals=2),
            np.round(np.mean(nearest), decimals=2),
            gap_accept_dist,
            gap_accept_n,
        )

    def calculate_euclidean_distance(self, states_in_buffer, state_incoming):
        """Compute Euclidean-distance statistics between a state and the buffer.

        Parameters
        ----------
        states_in_buffer : sequence
            Indexable collection whose items are mappings with a ``"state"``
            entry.
        state_incoming : array-like or torch.Tensor
            Only ``state_incoming[0]`` is compared (presumably a batch of one
            — confirm with callers).

        Returns
        -------
        tuple
            ``(mean_dist, min_dist, mean_min_n_dist, gap_accept_dist,
            gap_accept_n)``.

        Raises
        ------
        ValueError
            If ``states_in_buffer`` is empty.
        """
        states_in_buffer = self._to_numpy(states_in_buffer)
        state_incoming = self._to_numpy(state_incoming)

        buffer_size = len(states_in_buffer)
        if buffer_size == 0:
            raise ValueError("states_in_buffer is empty; cannot compute distances")

        query = state_incoming[0]
        # Indexed access (not iteration) keeps buffers that only implement
        # __len__/__getitem__ working exactly as before.
        dist_arr = np.array(
            [
                np.linalg.norm(states_in_buffer[i]["state"] - query)
                for i in range(buffer_size)
            ]
        )
        return self._summarize(dist_arr)

    def evaluate(self, states_in_buffer, state: np.ndarray):
        """Evaluate the new-state detector on a single state.

        Returns
            - a score (float) that represents the novelty of the state,
              selected by the ``option`` given at construction

        Raises
            - ValueError if ``option`` is not a recognised statistic name
        """
        (
            mean_dist,
            min_dist,
            mean_min_n_dist,
            gap_accept_dist,
            _gap_accept_len,
        ) = self.calculate_euclidean_distance(states_in_buffer, state)

        if self._debug:
            print(
                "all kind of distances for new state:",
                mean_dist,
                min_dist,
                mean_min_n_dist,
                gap_accept_dist,
            )

        scores = {
            "mean_dist": mean_dist,
            "min_dist": min_dist,
            "mean_min_n_dist": mean_min_n_dist,
            "gap_accept_dist": gap_accept_dist,
        }
        try:
            return scores[self._option]
        except KeyError:
            # Previously this called exit(), killing the whole process.
            raise ValueError(
                f"unknown option {self._option!r}; expected one of {list(scores)}"
            ) from None





class WassersteinDistance(NewStateDetector):
    """Score a state's novelty by its 1-D Wasserstein distance to buffered states.

    Each state vector is passed to ``scipy.stats.wasserstein_distance``, which
    treats the vector's entries as samples of a 1-D empirical distribution.

    Parameters
    ----------
    min_n : int
        Number of nearest buffered states averaged for ``"mean_min_n_dist"``.
    threshold_accept : float
        Distances ``<= threshold_accept`` count as accepted matches for
        ``"gap_accept_dist"``; 0 means only exact matches are accepted.
    option : str
        Statistic returned by :meth:`evaluate`: one of ``"mean_dist"``,
        ``"min_dist"``, ``"mean_min_n_dist"``, ``"gap_accept_dist"``.
    debug : bool
        Print all statistics on every :meth:`evaluate` call.
    """

    def __init__(self, min_n=2, threshold_accept=0, option="mean_dist", debug=False):
        self._min_n = min_n
        self._threshold_accept = threshold_accept
        self._debug = debug
        self._option = option  # {mean_dist, min_dist, mean_min_n_dist, gap_accept_dist}

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array when it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _summarize(self, dist_arr):
        """Reduce a 1-D array of distances to the statistics tuple.

        Returns ``(mean, min, mean of the min_n smallest, mean of accepted
        distances or NaN, number of accepted distances)``; the first three
        are rounded to 2 decimals.
        """
        accepted = dist_arr[dist_arr <= self._threshold_accept]
        gap_accept_n = len(accepted)
        gap_accept_dist = np.mean(accepted) if gap_accept_n else np.nan
        nearest = np.sort(dist_arr)[: self._min_n]
        # np.round rather than np.round_, which NumPy 2.0 removed.
        return (
            np.round(np.mean(dist_arr), decimals=2),
            np.round(np.min(dist_arr), decimals=2),
            np.round(np.mean(nearest), decimals=2),
            gap_accept_dist,
            gap_accept_n,
        )

    def calculate_ws_distance(self, states_in_buffer, state_incoming):
        """Compute Wasserstein-distance statistics between a state and the buffer.

        Parameters
        ----------
        states_in_buffer : sequence
            Indexable collection whose items are mappings with a ``"state"``
            entry.
        state_incoming : array-like or torch.Tensor
            Only ``state_incoming[0]`` is compared (presumably a batch of one
            — confirm with callers).

        Returns
        -------
        tuple
            ``(mean_dist, min_dist, mean_min_n_dist, gap_accept_dist,
            gap_accept_n)``.

        Raises
        ------
        ValueError
            If ``states_in_buffer`` is empty.
        """
        states_in_buffer = self._to_numpy(states_in_buffer)
        state_incoming = self._to_numpy(state_incoming)

        buffer_size = len(states_in_buffer)
        if buffer_size == 0:
            raise ValueError("states_in_buffer is empty; cannot compute distances")

        query = state_incoming[0]
        # Indexed access (not iteration) keeps buffers that only implement
        # __len__/__getitem__ working exactly as before.
        dist_arr = np.array(
            [
                wasserstein_distance(states_in_buffer[i]["state"], query)
                for i in range(buffer_size)
            ]
        )
        return self._summarize(dist_arr)

    def evaluate(self, states_in_buffer, state: np.ndarray):
        """Evaluate the new-state detector on a single state.

        Returns
            - a score (float) that represents the novelty of the state,
              selected by the ``option`` given at construction

        Raises
            - ValueError if ``option`` is not a recognised statistic name
        """
        (
            mean_dist,
            min_dist,
            mean_min_n_dist,
            gap_accept_dist,
            _gap_accept_len,
        ) = self.calculate_ws_distance(states_in_buffer, state)

        if self._debug:
            print(
                "all kind of distances for new state:",
                mean_dist,
                min_dist,
                mean_min_n_dist,
                gap_accept_dist,
            )

        scores = {
            "mean_dist": mean_dist,
            "min_dist": min_dist,
            "mean_min_n_dist": mean_min_n_dist,
            "gap_accept_dist": gap_accept_dist,
        }
        try:
            return scores[self._option]
        except KeyError:
            # Previously this called exit(), killing the whole process.
            raise ValueError(
                f"unknown option {self._option!r}; expected one of {list(scores)}"
            ) from None


